import numpy as np
import PIL.Image
import re
import torch
from torch.utils.data import Dataset
from transformers import T5TokenizerFast, T5ForConditionalGeneration, T5Tokenizer


def convert_propbank_to_readable_and_split(input_text, role_mapping):
    """Rewrite PropBank role openers into readable labels and split on periods.

    Every ``[ROLE:`` opener whose ROLE matches ``ARG<digits>``,
    ``ARGM-<UPPERCASE>`` or ``V`` is rewritten:

    * roles present in ``role_mapping`` become ``[<label>:``;
    * an unmapped ``V`` becomes ``[Verb:``;
    * any other matched role is left untouched.

    The rewritten text is then split on every ``.``. Each piece is stripped
    and empty pieces are dropped.

    Args:
        input_text: PropBank-annotated text.
        role_mapping: dict from PropBank role name to readable label.

    Returns:
        List of non-empty, stripped sentence strings.
    """
    def _relabel(hit):
        tag = hit.group(1)
        if tag in role_mapping:
            label = role_mapping[tag]
        elif tag == "V":  # the predicate itself
            label = "Verb"
        else:
            return hit.group(0)
        return f"[{label}:"

    readable = re.sub(r"\[(ARG\d+|ARGM-[A-Z]+|V):", _relabel, input_text)
    pieces = (chunk.strip() for chunk in readable.split('.'))
    return [piece for piece in pieces if piece]



def extract_background_color_style(input_text_list):
    """Pull the background-colour style out of each description.

    For each text, the lines of ``text.strip()`` are scanned in order. The
    value is taken from the first line containing
    ``"### Background color style:"``: the text after the first occurrence of
    that marker, stripped. Texts with no such line get ``"No style"``.

    Args:
        input_text_list: sequence of description strings.

    Returns:
        List of style strings, one per input text.
    """
    marker = "### Background color style:"
    styles = []
    for text in input_text_list:
        found = "No style"
        for row in text.strip().split('\n'):
            if marker in row:
                found = row.split(marker)[1].strip()
                break
        styles.append(found)
    return styles


def combine_propbank_description(new_propbank_obj1_des_list, new_propbank_obj2_des_list):
    """Merge the two objects' PropBank descriptions into at most four sentences.

    For each index of ``new_propbank_obj1_des_list``:

    1. Convert the object-1 text and the object-2 text at the same index with
       ``convert_propbank_to_readable_and_split``, using a fixed
       PropBank-role → readable-label table.
    2. Concatenate object-1's sentences followed by object-2's.
    3. Keep only the first four.

    Raises IndexError if the object-2 list is shorter than the object-1 list.
    Extra object-2 entries are ignored.

    Returns:
        List of sentence lists, one per object-1 entry.
    """
    # Note: ARGM-PNC and ARGM-PRP intentionally share the "Purpose" label.
    readable_roles = {
        "ARG0": "Agent",
        "ARG1": "Patient/Theme",
        "ARG2": "Instrument/Benefactive/Attribute",
        "ARGM-ADV": "General purpose",
        "ARGM-CAU": "Cause",
        "ARGM-DIR": "Direction",
        "ARGM-DIS": "Discourse marker",
        "ARGM-EXT": "Extent",
        "ARGM-LOC": "Location",
        "ARGM-MNR": "Manner",
        "ARGM-MOD": "Modal verb",
        "ARGM-NEG": "Negation",
        "ARGM-PNC": "Purpose",
        "ARGM-PRD": "Predication",
        "ARGM-REC": "Reciprocal",
        "ARGM-TMP": "Temporal",
        "ARGM-PRP": "Purpose",
        "ARGM-COM": "Comitatives",
    }

    merged_per_sample = []
    for idx, obj1_text in enumerate(new_propbank_obj1_des_list):
        obj2_text = new_propbank_obj2_des_list[idx]
        merged = (convert_propbank_to_readable_and_split(obj1_text, readable_roles)
                  + convert_propbank_to_readable_and_split(obj2_text, readable_roles))
        merged_per_sample.append(merged[:4])
    return merged_per_sample


def extract_relative_position(input_text_list):
    """Extract the relative-position token from each text.

    The token is the run of uppercase letters after the first
    ``"### Relative Position: "`` in the text. Texts without a match get
    ``"No position"``.

    Args:
        input_text_list: sequence of strings.

    Returns:
        List of position tokens, one per input text.
    """
    matcher = re.compile(r"### Relative Position: ([A-Z]+)")
    positions = []
    for text in input_text_list:
        hit = matcher.search(text)
        positions.append(hit.group(1) if hit else "No position")
    return positions




def remove_error(fmri, description_list, position_list, propbank_obj1_des, propbank_obj2_des):
    """Drop samples whose object-1 PropBank description equals ``"Error"``.

    Indices run over ``range(len(description_list))``. Each index whose
    ``propbank_obj1_des`` entry is not ``"Error"`` is kept in all five inputs.
    The lengths of the four filtered text lists are printed.

    Returns:
        Tuple of five lists, in the same order as the arguments:
        (fmri, descriptions, positions, obj1 descriptions, obj2 descriptions).
    """
    kept = [idx for idx in range(len(description_list)) if propbank_obj1_des[idx] != "Error"]
    sources = (fmri, description_list, position_list, propbank_obj1_des, propbank_obj2_des)
    filtered = tuple([src[idx] for idx in kept] for src in sources)
    print(len(filtered[1]), len(filtered[2]), len(filtered[3]), len(filtered[4]))
    return filtered

def concat_pos_color(fmri, position_list, color_style_list, combined_propbank):
    """Append position and background-colour tags to each sample's four sentences.

    Sentences 0-1 (object 1) get the mapped relative position. Sentences 2-3
    (object 2) get the opposite side from ``offset_dict``. All four get the
    sample's background colour.

    Samples with fewer than four sentences are reported and skipped without
    being modified. Position tokens missing from ``position_map`` are
    treated like ``'No position'`` instead of raising KeyError.

    Args:
        fmri: indexable fMRI entries aligned with the other sequences.
        position_list: raw position tokens, one per sample.
        color_style_list: background-colour strings, one per sample.
        combined_propbank: per-sample lists of sentence strings. For each kept
            sample the first four entries are updated in place.

    Returns:
        Tuple ``(fmri_res, combined_propbank_res)``: the kept fMRI entries and
        their decorated sentence lists. The lists are the same objects held
        in ``combined_propbank``.

    Raises:
        IndexError: if ``fmri``, ``position_list`` or ``color_style_list`` is
            shorter than ``combined_propbank``. This is no longer silently
            swallowed.
    """
    position_map = {
        'LEFT': 'LEFT',
        'RIGHT': 'RIGHT',
        'TOP': 'TOP',
        'BOTTOM': 'BOTTOM',
        'BELOW': 'BOTTOM',
        'ABOVE': 'TOP',
        'AROUND': 'TOP',
        'ON': 'TOP',
        'CENTER': 'TOP',
        'FRONT': 'TOP',
        'INSIDE': 'TOP',
        'IN': 'TOP',
        'BEHIND': 'TOP',
        'No position': 'LEFT',
        'BACKGROUND': 'TOP',
        'BACK': 'TOP',
        'NONE': 'LEFT'
    }
    offset_dict = {
        'LEFT': 'RIGHT',
        'RIGHT': 'LEFT',
        'TOP': 'BOTTOM',
        'BOTTOM': 'TOP',
    }
    # The upstream regex accepts any uppercase word, so unseen tokens occur.
    # Fall back to the "no position" mapping rather than aborting the split.
    default_position = position_map['No position']

    fmri_res = []
    combined_propbank_res = []

    for i, sentences in enumerate(combined_propbank):
        temp_position = position_list[i]
        temp_color = color_style_list[i]
        if len(sentences) < 4:
            # Checked up front so a short sample is skipped before any of its
            # sentences are modified (the old try/except left it half-edited).
            print("Error in concatenating position and color style")
            print(f"Index: {i}, Position: {temp_position}, Color Style: {temp_color}")
            print(f"Combined Propbank: {sentences}")
            continue
        obj1_temp_position = position_map.get(temp_position, default_position)
        obj2_temp_position = offset_dict[obj1_temp_position]
        # NOTE: "Postion" is misspelled, but it is part of the generated label
        # text; it is kept unchanged so labels stay consistent with existing
        # data and any model trained on them.
        color_tag = f"[Background color: {temp_color}]"
        obj1_suffix = f"[Postion: {obj1_temp_position}]" + color_tag
        obj2_suffix = f"[Postion: {obj2_temp_position}]" + color_tag
        for j, suffix in enumerate((obj1_suffix, obj1_suffix, obj2_suffix, obj2_suffix)):
            sentences[j] = sentences[j] + suffix
        fmri_res.append(fmri[i])
        combined_propbank_res.append(sentences)
    print(len(fmri_res), len(combined_propbank_res))
    return fmri_res, combined_propbank_res




def load_data_lists(split, save_dir):
    """Load the fMRI, images, captions, COCO ids and responses for one split.

    Reads ``{save_dir}/{split}.npz`` with ``allow_pickle=True``. For the
    ``"train"`` split the stored image array is returned as-is. For any other
    split each image is converted with ``PIL.Image.fromarray``.

    Args:
        split: split name; also the archive's base file name.
        save_dir: directory containing the archive.

    Returns:
        Tuple ``(fmri, images, captions, coco_ids, responses)``. ``captions``,
        ``coco_ids`` and ``responses`` are Python lists; ``captions`` is ``[]``
        when the stored array is empty.
    """
    # NOTE(review): allow_pickle=True can execute code embedded in a crafted
    # archive; only load files from trusted sources.
    # The context manager closes the zip file handle. Each key is read once,
    # because every NpzFile lookup re-reads the member from disk.
    with np.load(f'{save_dir}/{split}.npz', allow_pickle=True) as data:
        fmri = data['fmri']
        images = data['images']
        captions = data['captions']
        coco_ids = data['coco_ids'].tolist()
        responses = data['responses'].tolist()

    if split != "train":
        images = [PIL.Image.fromarray(img) for img in images]

    return (
        fmri,
        images,
        captions.tolist() if len(captions) > 0 else [],
        coco_ids,
        responses,
    )

class FMRITextDataset(Dataset):
    """Dataset pairing fMRI vectors with tokenized sentence lists.

    Each item is a dict with:

    * ``"fmri"``: the sample's vector as a float tensor;
    * ``"labels"``: one 1-D token-id tensor per sentence, from the tokenizer
      called with ``padding="max_length"`` and ``truncation=True``;
    * ``"text"``: the raw sentence list.
    """

    def __init__(self, fmri_vectors, text_descriptions, tokenizer, max_length=128):
        # NOTE(review): earlier comment said [num_samples, 15724]; confirm
        # against the actual data.
        self.fmri_vectors = fmri_vectors
        # One list of decorated PropBank sentence strings per sample.
        self.text_descriptions = text_descriptions
        self.tokenizer = tokenizer
        self.max_length = max_length

    def __len__(self):
        return len(self.fmri_vectors)

    def _encode(self, sentence):
        """Tokenize one sentence and drop the leading batch dimension."""
        encoded = self.tokenizer(
            sentence,
            max_length=self.max_length,
            padding="max_length",
            truncation=True,
            return_tensors="pt"
        )
        return encoded.input_ids.squeeze(0)

    def __getitem__(self, idx):
        sentences = self.text_descriptions[idx]
        # NOTE(review): padded positions keep the tokenizer's pad id. If these
        # are fed to T5 as labels, confirm the training loop masks them
        # (e.g. to -100).
        return {
            "fmri": torch.tensor(self.fmri_vectors[idx], dtype=torch.float),
            "labels": [self._encode(sentence) for sentence in sentences],
            "text": sentences
        }



def _load_npz_arrays(path, *keys):
    """Return the arrays stored under ``keys`` in the ``.npz`` at ``path``, closing the file."""
    with np.load(path, allow_pickle=True) as archive:
        return tuple(archive[key] for key in keys)


def _prepare_split(fmri, save_dir):
    """Run the description/PropBank pipeline for one split directory.

    Steps:

    1. Read ``description.npz``, ``description_position.npz`` and
       ``propbank_des.npz`` from ``save_dir``.
    2. Drop samples flagged "Error" (``remove_error``).
    3. Parse positions and background colours.
    4. Merge the PropBank sentences.
    5. Decorate the sentences with position and colour tags.

    Returns:
        Tuple ``(fmri_final, combined_propbank_final)`` from ``concat_pos_color``.
    """
    (description,) = _load_npz_arrays(f"{save_dir}/description.npz", "response_list")
    (position,) = _load_npz_arrays(f"{save_dir}/description_position.npz", "response_list")
    obj1_des, obj2_des = _load_npz_arrays(
        f"{save_dir}/propbank_des.npz", "propbank_obj1_des", "propbank_obj2_des"
    )

    fmri_kept, description_kept, position_kept, obj1_kept, obj2_kept = remove_error(
        fmri, description, position, obj1_des, obj2_des
    )

    position_list = extract_relative_position(position_kept)
    color_style_list = extract_background_color_style(description_kept)
    combined = combine_propbank_description(obj1_kept, obj2_kept)
    fmri_final, combined_final = concat_pos_color(fmri_kept, position_list, color_style_list, combined)
    print("!!!!!!", len(fmri_final), len(combined_final))
    return fmri_final, combined_final


def load_all_data(train_npz="data.npz", train_dir="data/train", test_dir="data/test",
                  tokenizer_name="t5-base", max_length=128):
    """Build the train and test ``FMRITextDataset`` objects.

    Train fMRI comes from the ``train_fmri`` array in ``train_npz``. Test
    fMRI comes from the ``fmri`` array in ``{test_dir}/test.npz``. Each
    split's descriptions, positions and PropBank annotations are read from
    its own directory. The defaults reproduce the previously hard-coded
    paths and tokenizer.

    Args:
        train_npz: archive holding the training fMRI under ``train_fmri``.
        train_dir: directory with the training split's annotation archives.
        test_dir: directory with ``test.npz`` and the test annotation archives.
        tokenizer_name: name passed to ``T5TokenizerFast.from_pretrained``.
        max_length: token length passed to both datasets.

    Returns:
        Tuple ``(train_data, test_data)`` of ``FMRITextDataset`` instances
        sharing one tokenizer.
    """
    (fmri_train,) = _load_npz_arrays(train_npz, "train_fmri")
    fmri_train_final, combined_train_final = _prepare_split(fmri_train, train_dir)

    # Read only the fMRI array. load_data_lists would also decode every test
    # image through PIL, and that result was discarded.
    (fmri_test,) = _load_npz_arrays(f"{test_dir}/test.npz", "fmri")
    # Fixes the former path bug: the test description file was opened as
    # f"{save_dir}description.npz" (missing "/"). _prepare_split builds paths
    # consistently for both splits.
    fmri_test_final, combined_test_final = _prepare_split(fmri_test, test_dir)

    tokenizer = T5TokenizerFast.from_pretrained(tokenizer_name)

    train_data = FMRITextDataset(fmri_train_final, combined_train_final, tokenizer, max_length=max_length)
    test_data = FMRITextDataset(fmri_test_final, combined_test_final, tokenizer, max_length=max_length)
    return train_data, test_data



# load_all_data()
